{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "\n",
    "## Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg\n",
    "!pip install unidecode\n",
    "\n",
    "# ## Install NeMo\n",
    "BRANCH = 'r1.0.0rc1'\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n",
    "\n",
    "## Install TorchAudio\n",
    "!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "Who Speaks When? Speaker Diarization is the task of segmenting audio recordings by speaker labels. \n",
    "A diarization system consists of Voice Activity Detection (VAD) model to get the time stamps of audio where speech is being spoken ignoring the background and Speaker Embeddings model to get speaker embeddings on segments that were previously time stamped. These speaker embeddings would then be clustered into clusters based on number of speakers present in the audio recording.\n",
    "\n",
    "In NeMo we support both **oracle VAD** and **non-oracle VAD** diarization. \n",
    "\n",
    "In this tutorial, we shall first demonstrate how to perform diarization with a oracle VAD time stamps (we assume we already have speech time stamps) and pretrained speaker verification model which can be found in tutorial for [Speaker and Recognition and Verification in NeMo](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_recognition/Speaker_Recognition_Verification.ipynb).\n",
    "\n",
    "In [second part](#ORACLE-VAD-DIARIZATION) we show how to perform VAD and then diarization if ground truth timestamped speech were not available (non-oracle VAD). We also have tutorials for [VAD training in NeMo](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/06_Voice_Activiy_Detection.ipynb) and [online offline microphone inference](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/07_Online_Offline_Microphone_VAD_Demo.ipynb), where you can custom your model and training/finetuning on your own data.\n",
    "\n",
    "For demonstration purposes we would be using simulated audio from [an4 dataset](http://www.speech.cs.cmu.edu/databases/an4/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import wget\n",
    "ROOT = os.getcwd()\n",
    "data_dir = os.path.join(ROOT,'data')\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "an4_audio = os.path.join(data_dir,'an4_diarize_test.wav')\n",
    "an4_rttm = os.path.join(data_dir,'an4_diarize_test.rttm')\n",
    "if not os.path.exists(an4_audio):\n",
    "    an4_audio_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav\"\n",
    "    an4_audio = wget.download(an4_audio_url, data_dir)\n",
    "if not os.path.exists(an4_rttm):\n",
    "    an4_rttm_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.rttm\"\n",
    "    an4_rttm = wget.download(an4_rttm_url, data_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot and listen to the audio and visualize the RTTM speaker labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import librosa\n",
    "\n",
    "sr = 16000\n",
    "signal, sr = librosa.load(an4_audio,sr=sr) \n",
    "\n",
    "fig,ax = plt.subplots(1,1)\n",
    "fig.set_figwidth(20)\n",
    "fig.set_figheight(2)\n",
    "plt.plot(np.arange(len(signal)),signal,'gray')\n",
    "fig.suptitle('Reference merged an4 audio', fontsize=16)\n",
    "plt.xlabel('time (secs)', fontsize=18)\n",
    "ax.margins(x=0)\n",
    "plt.ylabel('signal strength', fontsize=16);\n",
    "a,_ = plt.xticks();plt.xticks(a,a/sr);\n",
    "\n",
    "IPython.display.Audio(an4_audio)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We would use [pyannote_metrics](https://pyannote.github.io/pyannote-metrics/) for visualization and score calculation purposes. Hence all the labels in rttm formats would eventually be converted to pyannote objects, we created two helper functions rttm_to_labels (for NeMo intermediate processing) and labels_to_pyannote_object for scoring and visualization format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.speaker_utils import rttm_to_labels, labels_to_pyannote_object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load ground truth RTTM labels and view the reference Annotation timestamps visually"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# view the sample rttm file\n",
    "!cat {an4_rttm}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = rttm_to_labels(an4_rttm)\n",
    "reference = labels_to_pyannote_object(labels)\n",
    "print(labels)\n",
    "reference"
   ]
  },
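  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the conversion from one RTTM line to a label string can be sketched in plain Python. This is a minimal sketch, not NeMo's implementation of `rttm_to_labels`: the function name `rttm_line_to_label` is hypothetical, the RTTM line is made up, and the sketch assumes the standard RTTM field order (onset in the fourth field, duration in the fifth, speaker name in the eighth) and a `start end speaker` label layout.\n",
    "\n",
    "```python\n",
    "def rttm_line_to_label(line):\n",
    "    # Hypothetical helper for illustration only; not part of NeMo.\n",
    "    # Standard RTTM fields: type, file id, channel, onset, duration,\n",
    "    # orthography, speaker type, speaker name, confidence, lookahead.\n",
    "    fields = line.strip().split()\n",
    "    onset = float(fields[3])\n",
    "    duration = float(fields[4])\n",
    "    speaker = fields[7]\n",
    "    # Assumed label layout: 'start end speaker', with end = onset + duration.\n",
    "    return '{:.3f} {:.3f} {}'.format(onset, onset + duration, speaker)\n",
    "\n",
    "\n",
    "# Made-up RTTM line, not taken from the an4 sample.\n",
    "example = 'SPEAKER demo 1 0.50 2.25 <NA> <NA> speaker_0 <NA> <NA>'\n",
    "print(rttm_line_to_label(example))\n",
    "```\n",
    "\n",
    "The key point is that RTTM stores an onset and a duration, so the end time of a segment has to be computed as their sum."
   ]
  },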
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Speaker Diarization scripts commonly expects two files:\n",
    "1. paths2audio_files : either list of audio file paths or file containing paths to audio files for which we need to perform diarization.\n",
    "2. path2groundtruth_rttm_files (optional): either list of rttm file paths or file containing paths to rttm files (this can be passed if we need to calculate DER rate based on our ground truth rttm files).\n",
    "\n",
    "**Note** we expect audio and corresponding RTTM have **same base name** and the name should be **unique**. \n",
    "\n",
    "For eg: if audio file name is **test_an4**.wav, if provided we expect corresponding rttm file name to be **test_an4**.rttm (note the matching **test_an4** base name)\n",
    "\n",
    "Now let's create paths2audio_files list (or file) for which we need to perform diarization"
   ]
  },
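  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The base-name requirement can be checked before running diarization. The following is a minimal sketch in plain Python, not part of NeMo: `base_name` and `check_pairs` are hypothetical helpers, and the paths in the usage line are made up.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def base_name(path):\n",
    "    # File name without its directory or extension.\n",
    "    return os.path.splitext(os.path.basename(path))[0]\n",
    "\n",
    "\n",
    "def check_pairs(audio_paths, rttm_paths):\n",
    "    # Hypothetical helper: require unique audio base names and return\n",
    "    # the audio base names that have no RTTM file with the same base name.\n",
    "    audio_names = [base_name(p) for p in audio_paths]\n",
    "    if len(set(audio_names)) != len(audio_names):\n",
    "        raise ValueError('audio base names must be unique')\n",
    "    rttm_names = {base_name(p) for p in rttm_paths}\n",
    "    return [name for name in audio_names if name not in rttm_names]\n",
    "\n",
    "\n",
    "# An empty list means every audio file has a matching RTTM file.\n",
    "print(check_pairs(['data/test_an4.wav'], ['labels/test_an4.rttm']))\n",
    "```\n",
    "\n",
    "Matching on the base name alone means the audio and RTTM files may live in different directories, which is why the names themselves have to be unique."
   ]
  },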
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths2audio_files = [an4_audio]\n",
    "print(paths2audio_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly create` path2groundtruth_rttm_files` list (this is optional, and needed for score calculation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path2groundtruth_rttm_files = [an4_rttm]\n",
    "print(path2groundtruth_rttm_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ORACLE-VAD DIARIZATION"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Oracle-vad diarization is to compute speaker embeddings from known speech label timestamps rather than depending on VAD output. This step can also be used to run speaker diarization with rttms generated from any external VAD, not just VAD model from NeMo.\n",
    "\n",
    "For it, the first step is to start converting reference audio rttm(vad) time stamps to oracle manifest file. This manifest file would be sent to our speaker diarizer to extract embeddings.\n",
    "For that let's use write_rttm2manifest function, that takes paths2audio_files and paths2rttm_files as arguments "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.speaker_utils import write_rttm2manifest\n",
    "output_dir = os.path.join(ROOT, 'oracle_vad')\n",
    "os.makedirs(output_dir,exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oracle_manifest = os.path.join(output_dir,'oracle_manifest.json')\n",
    "write_rttm2manifest(paths2audio_files=paths2audio_files,\n",
    "                    paths2rttm_files=path2groundtruth_rttm_files,\n",
    "                    manifest_file=oracle_manifest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat {oracle_manifest}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our config file is based on [hydra](https://hydra.cc/docs/intro/). \n",
    "With hydra config, we ask users to provide values to variables that were filled with **???**, these are mandatory fields and scripts expect them for successful runs. And notice some variables were filled with **null** are optional variables. Those could be provided if needed but are not mandatory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "MODEL_CONFIG = os.path.join(data_dir,'speaker_diarization.yaml')\n",
    "if not os.path.exists(MODEL_CONFIG):\n",
    "    config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_recognition/conf/speaker_diarization.yaml\"\n",
    "    MODEL_CONFIG = wget.download(config_url,data_dir)\n",
    "config = OmegaConf.load(MODEL_CONFIG)\n",
    "print(OmegaConf.to_yaml(config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can perform speaker diarization based on timestamps generated from ground truth rttms rather than generating through VAD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pretrained_speaker_model='speakerdiarization_speakernet'\n",
    "config.diarizer.paths2audio_files = paths2audio_files\n",
    "config.diarizer.path2groundtruth_rttm_files = path2groundtruth_rttm_files\n",
    "config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs\n",
    "\n",
    "config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
    "# Ignoring vad we just need to pass the manifest file we created\n",
    "config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.models import ClusteringDiarizer\n",
    "oracle_model = ClusteringDiarizer(cfg=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And lets diarize\n",
    "oracle_model.diarize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With DER 0 -> means it clustered speaker embeddings correctly. Let's view "
   ]
  },
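  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what the score measures, DER is commonly defined as the sum of false alarm, missed detection, and speaker confusion time, divided by the total reference speech time. The sketch below only illustrates that arithmetic; `diarization_error_rate` is a hypothetical helper, not the scoring code used by NeMo or pyannote.metrics, and the durations are made-up numbers.\n",
    "\n",
    "```python\n",
    "def diarization_error_rate(false_alarm, missed, confusion, total):\n",
    "    # All arguments are durations in seconds; total is the amount of\n",
    "    # reference speech. Hypothetical helper for illustration only.\n",
    "    return (false_alarm + missed + confusion) / total\n",
    "\n",
    "\n",
    "# Perfect output: no error time at all.\n",
    "print(diarization_error_rate(0.0, 0.0, 0.0, 10.0))\n",
    "# Made-up errors: 2 seconds of error over 10 seconds of reference speech.\n",
    "print(diarization_error_rate(0.5, 1.0, 0.5, 10.0))\n",
    "```\n",
    "\n",
    "With oracle VAD the speech regions come from the ground truth, so any remaining error would come mainly from speaker confusion."
   ]
  },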
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat {output_dir}/pred_rttms/an4_diarize_test.rttm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_labels = rttm_to_labels(output_dir+'/pred_rttms/an4_diarize_test.rttm')\n",
    "hypothesis = labels_to_pyannote_object(pred_labels)\n",
    "hypothesis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VAD DIARIZATION"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this method we compute VAD time stamps using NeMo VAD model on `paths2audio_files` and then use these time stamps of speech label to find speaker embeddings followed by clustering them into num of speakers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we proceed let's look at the speaker diarization config, which we would be depending up on for vad computation\n",
    "and speaker embedding extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(OmegaConf.to_yaml(config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen most of the variables in config are self explanatory \n",
    "with VAD variables under vad section and speaker related variables under speaker embeddings section. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To perform VAD based diarization we can ignore `oracle_vad_manifest` in `speaker_embeddings` section for now and needs to fill up the rest. We also needs to provide pretrained `model_path` of vad and speaker embeddings .nemo models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pretrained_vad = 'vad_marblenet'\n",
    "pretrained_speaker_model = 'speakerdiarization_speakernet'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note in this tutorial, we use the VAD model MarbleNet-3x2 introduced and published in [ICASSP MarbleNet](https://arxiv.org/pdf/2010.13886.pdf). You might need to tune on dev set similar to your dataset if you would like to improve the performance.\n",
    "\n",
    "And the speakerNet-M-Diarization model achieves 7.3% confusion error rate on CH109 set with oracle vad. This model is trained on voxceleb1, voxceleb2, Fisher, SwitchBoard datasets. So for more improved performance specific to your dataset, finetune speaker verification model with a devset similar to your test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = os.path.join(ROOT,'outputs')\n",
    "config.diarizer.paths2audio_files = paths2audio_files\n",
    "config.diarizer.path2groundtruth_rttm_files = path2groundtruth_rttm_files\n",
    "config.diarizer.out_dir = output_dir # Directory to store intermediate files and prediction outputs\n",
    "config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
    "\n",
    "#Here we use our inhouse pretrained NeMo VAD \n",
    "config.diarizer.vad.model_path = pretrained_vad\n",
    "config.diarizer.vad.window_length_in_sec = 0.15\n",
    "config.diarizer.vad.shift_length_in_sec = 0.01\n",
    "config.diarizer.vad.threshold = 0.8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we passed all the variables we needed lets initialize the clustering model with above config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.models import ClusteringDiarizer\n",
    "sd_model = ClusteringDiarizer(cfg=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And Diarize with single line of code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sd_model.diarize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen, we first performed VAD, then with the timestamps created in `{output_dir}/vad_outputs` by VAD we calculated speaker embeddings (`{output_dir}/speaker_outputs/embeddings/`) which are then clustered using spectral clustering. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To generate VAD predicted time step. We perform VAD inference to have frame level prediction &#8594; (optional: use decision smoothing) &#8594; given `threshold`,  write speech segment to RTTM-like time stamps manifest.\n",
    "\n",
    "we use vad decision smoothing (87.5% overlap median) as described [here](https://github.com/NVIDIA/NeMo/blob/r1.0.0rc1/nemo/collections/asr/parts/vad_utils.py)\n",
    "\n",
    "you can also tune the threshold on your dev set. Use this provided [script](https://github.com/NVIDIA/NeMo/blob/r1.0.0rc1/scripts/voice_activity_detection/vad_tune_threshold.py)"
   ]
  },
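  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The thresholding step can be sketched in plain Python. This is a simplified sketch under stated assumptions, not the code in `vad_utils.py`: `frames_to_segments` is a hypothetical helper, each frame is assumed to cover one `shift`-second step, and decision smoothing is skipped. The probabilities and the `shift` value below are made up for illustration.\n",
    "\n",
    "```python\n",
    "def frames_to_segments(probs, threshold, shift):\n",
    "    # Hypothetical helper: turn frame-level speech probabilities into\n",
    "    # (start, end) speech segments in seconds. Frame i is assumed to\n",
    "    # cover the interval [i * shift, (i + 1) * shift).\n",
    "    segments = []\n",
    "    start = None\n",
    "    for i, p in enumerate(probs):\n",
    "        if p >= threshold and start is None:\n",
    "            start = i  # speech begins at this frame\n",
    "        elif p < threshold and start is not None:\n",
    "            segments.append((start * shift, i * shift))  # speech ended\n",
    "            start = None\n",
    "    if start is not None:\n",
    "        # Speech continues until the end of the recording.\n",
    "        segments.append((start * shift, len(probs) * shift))\n",
    "    return segments\n",
    "\n",
    "\n",
    "probs = [0.1, 0.9, 0.95, 0.85, 0.2, 0.1, 0.9, 0.9]\n",
    "print(frames_to_segments(probs, threshold=0.8, shift=0.5))\n",
    "```\n",
    "\n",
    "Raising the threshold in this sketch can only shrink or remove speech segments, which is the trade-off between missed speech and false alarms that tuning on a dev set explores."
   ]
  },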
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# VAD predicted time stamps\n",
    "from nemo.collections.asr.parts.vad_utils import extract_labels, plot\n",
    "\n",
    "plot(paths2audio_files[0],\n",
    "     'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', \n",
    "     path2groundtruth_rttm_files[0],\n",
    "     threshold=config.diarizer.vad.threshold)\n",
    "print(f\"threshold: {config.diarizer.vad.threshold}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predicted outputs are written to `output_dir/pred_rttms` and see how we predicted along with VAD prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat outputs/pred_rttms/an4_diarize_test.rttm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_labels = rttm_to_labels('outputs/pred_rttms/an4_diarize_test.rttm')\n",
    "hypothesis = labels_to_pyannote_object(pred_labels)\n",
    "hypothesis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Storing and Restoring models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can save the whole config and model parameters in a single .nemo and restore from it anytime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oracle_model.save_to(os.path.join(output_dir,'diarize.nemo'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Restore from saved model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del oracle_model\n",
    "import nemo.collections.asr as nemo_asr\n",
    "restored_model = nemo_asr.models.ClusteringDiarizer.restore_from(os.path.join(output_dir,'diarize.nemo'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ADD ON - ASR "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IPython.display.Audio(an4_audio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quartznet = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name=\"QuartzNet15x5Base-En\")\n",
    "for fname, transcription in zip(paths2audio_files, quartznet.transcribe(paths2audio_files=paths2audio_files)):\n",
    "  print(f\"Audio in {fname} was recognized as: {transcription}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
